Implement a new api that will be switching between asm and hip pa internally #1821

Merged
valarLip merged 12 commits into main from common_hip_asm_pa_inerface
Jan 16, 2026
Conversation

@JohnNikolay84
Contributor

@JohnNikolay84 JohnNikolay84 commented Jan 12, 2026

Inference engines should now call paged_attention_common with a shuffled KV-cache layout, and aiter will internally decide between the asm and HIP kernels. HIP is more performant at lower concurrencies (< 128). The test_pa.py unit test has also been updated to cover the new interface.

Since the asm and HIP kernels require scales in different layouts, clients of this API are expected to provide both: a per-K/V-tensor scale value for HIP and an expanded per-block scale for asm.
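
To make the two scale layouts concrete, here is a minimal plain-Python sketch. The variable names, block count, and values are illustrative assumptions, not aiter's actual signature; only the layout distinction (per-tensor vs. expanded per-block) comes from the description above.

```python
# Hypothetical names/shapes illustrating the two scale layouts; this is
# NOT the actual aiter API, just the concept from the PR description.
num_blocks = 4  # number of KV-cache blocks (assumed for illustration)

# HIP path: a single scale value per K/V tensor.
k_scale_hip = 0.02
v_scale_hip = 0.03

# asm path: the same value expanded to one scale per KV-cache block,
# so the kernel can read a scale alongside each block it dequantizes.
k_scale_asm = [k_scale_hip] * num_blocks
v_scale_asm = [v_scale_hip] * num_blocks

print(len(k_scale_asm), k_scale_asm[0])  # 4 0.02
```

A client of the API would precompute both forms once per model load and pass them together, letting aiter pick whichever layout the selected kernel needs.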

Motivation

vLLM has had a binary choice between the HIP (C++) and asm kernels provided by the aiter framework. The HIP kernel is more performant at lower concurrencies; the asm kernel is more performant at higher concurrencies. It would be great if aiter could pick the better one depending on the input.

Technical Details

The HIP pa kernel under csrc/cpp_itfs now supports a variant where kv_cache is provided as a 5D tensor (asm style).
Aiter internally switches between the HIP and asm kernels based on the expected concurrency.
Inference engines should call the new API, paged_attention_common, with asm-shuffled KV-cache layouts, and also provide scratch space in case the 2-stage HIP kernel is selected.
No legacy API is affected.
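
The selection rule described in this PR can be sketched as follows. The concurrency threshold (128) and the quantized-cache redirect come from the description and the commit message below, but the function name, parameter names, and dtype strings are illustrative assumptions, not aiter's actual code:

```python
# Illustrative-only sketch of the dispatch heuristic: HIP for lower
# concurrency (< 128), asm otherwise, and asm unconditionally for
# quantized (int8/fp8) KV caches, since the HIP path does not support
# shuffled scales. Names and threshold are assumptions from the PR text.
HIP_CONCURRENCY_LIMIT = 128

def select_pa_kernel(num_seqs: int, kv_cache_dtype: str) -> str:
    if kv_cache_dtype in ("int8", "fp8"):
        return "asm"  # HIP lacks shuffled-scale support for quantized caches
    return "hip" if num_seqs < HIP_CONCURRENCY_LIMIT else "asm"

print(select_pa_kernel(64, "bf16"))   # low concurrency -> hip
print(select_pa_kernel(256, "bf16"))  # high concurrency -> asm
print(select_pa_kernel(64, "fp8"))    # quantized cache -> asm
```

Keeping the rule behind a single entry point is what lets inference engines stay agnostic: they always call paged_attention_common, and the heuristic can be retuned inside aiter without API changes.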

Test Plan

Run the qwen3 model and make sure accuracy has not regressed and performance has improved.
Create a new aiter unit test for the new API and make sure it passes.

Test Result

25% performance improvement at concurrency > 64 on qwen3
8% overall improvement in geomean OTPS for Qwen_Qwen3-235B-A22B


Inference engines should now call paged_attention_common with a
shuffled KV-cache layout, and aiter will internally decide between the
asm and HIP kernels. HIP is more performant at lower concurrencies (< 128).
A unit test has also been updated to cover the new interface.

Note that shuffled scales are not supported in the HIP kernel, so KV
caches in int8 or fp8 formats are always redirected to asm.
@JohnNikolay84 JohnNikolay84 requested review from a team and valarLip January 12, 2026 16:26
@JohnNikolay84 JohnNikolay84 requested review from fsx950223 and removed request for fsx950223 January 12, 2026 22:24
@fsx950223
Contributor

@copilot Does the PR pass the paged attention unit test?

Contributor

Copilot AI commented Jan 15, 2026

@fsx950223 I've opened a new pull request, #1848, to work on those changes. Once the pull request is ready, I'll request review from you.

@valarLip valarLip merged commit 22d1e35 into main Jan 16, 2026
17 checks passed
@valarLip valarLip deleted the common_hip_asm_pa_inerface branch January 16, 2026 02:53
@JohnNikolay84 JohnNikolay84 self-assigned this Jan 22, 2026
yzhou103 pushed a commit that referenced this pull request Jan 28, 2026
…ernally (#1821)

* Implement a new api that will be switching between asm and hip pa

Inference engines should now call paged_attention_common with a
shuffled KV-cache layout, and aiter will internally decide between the
asm and HIP kernels. HIP is more performant at lower concurrencies (< 128).
A unit test has also been updated to cover the new interface.

Note that shuffled scales are not supported in the HIP kernel, so KV
caches in int8 or fp8 formats are always redirected to asm.

* Delete op_tests/README_pa_merged_tests.md

* Delete op_tests/test_pa_merged.py

* Fix formatting according to Black requirements

* Fix one last place with broken formatting

* Remove modification to pa_v1, we already have pa for 5D kv cache

* Fix another formatting issue

* Add proper quant support for the common API

* Apply formatting

* Remove redundant parameters

* Remove redundant parameters

---------

Co-authored-by: Sergey Solo <ssolovye@amd.com>
Co-authored-by: Mikko Tukiainen <mikko.tukiainen@amd.com>
valarLip pushed a commit that referenced this pull request Mar 18, 2026
…ernally (#1821)

valarLip pushed a commit that referenced this pull request Mar 18, 2026
…ernally (#1821)
valarLip pushed a commit that referenced this pull request Mar 18, 2026
…ernally (#1821)